{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For Colab only: select TensorFlow 2.x before TensorFlow is imported below.\n",
    "\n",
    "try:\n",
    "  # %tensorflow_version only exists in Colab.\n",
    "  %tensorflow_version 2.x\n",
    "except Exception:\n",
    "  # Deliberately ignored so the notebook keeps running when the magic fails.\n",
    "  pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "cell_style": "split"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "cell_style": "split"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.nn import functional as F\n",
    "from torchvision import datasets, transforms\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "cell_style": "split",
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.1.0\n",
      "WARNING:tensorflow:From <ipython-input-3-2e82a26757ae>:2: is_gpu_available (from tensorflow.python.framework.test_util) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.config.list_physical_devices('GPU')` instead.\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "# Report the TensorFlow version and whether at least one GPU is visible.\n",
    "print(tf.__version__)\n",
    "# tf.test.is_gpu_available() is deprecated; list the physical GPU devices instead.\n",
    "print(len(tf.config.list_physical_devices('GPU')) > 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "cell_style": "split"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.4.0\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "# Report the installed PyTorch version and whether CUDA can be used.\n",
    "print(torch.__version__, torch.cuda.is_available(), sep='\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data Loading\n",
    "MNIST dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hyperparameters shared by the TensorFlow and PyTorch cells below.\n",
    "batch_size, learning_rate, epochs = 200, 0.01, 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "cell_style": "split"
   },
   "outputs": [],
   "source": [
    "# Load the MNIST train/test arrays through Keras.\n",
    "mnist_train, mnist_test = keras.datasets.mnist.load_data()\n",
    "(x, y), (x_test, y_test) = mnist_train, mnist_test\n",
    "\n",
    "# Wrap each (images, labels) pair as a dataset of individual examples.\n",
    "ds_train, ds_test = (\n",
    "    tf.data.Dataset.from_tensor_slices(split) for split in (mnist_train, mnist_test)\n",
    ")\n",
    "\n",
    "def preprocess(x, y):\n",
    "  \"\"\"Normalize one MNIST image and cast its label.\n",
    "\n",
    "  The image is scaled by 1/255 and then standardized with mean 0.1307 and\n",
    "  std 0.3081, the same constants the PyTorch loaders pass to\n",
    "  transforms.Normalize, so both frameworks see equally normalized inputs.\n",
    "\n",
    "  Args:\n",
    "    x: image tensor of any numeric dtype.\n",
    "    y: label tensor.\n",
    "\n",
    "  Returns:\n",
    "    Tuple (x, y) with x as float32 and y as int32.\n",
    "  \"\"\"\n",
    "  x = (tf.cast(x, tf.float32) / 255.0 - 0.1307) / 0.3081\n",
    "  y = tf.cast(y, tf.int32)\n",
    "  return x, y\n",
    "\n",
    "# Only the training data is shuffled; evaluation does not depend on sample order.\n",
    "ds_train = ds_train.map(preprocess).shuffle(1000).batch(batch_size)\n",
    "ds_test = ds_test.map(preprocess).batch(batch_size)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "cell_style": "split",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Same preprocessing for both splits: convert to a tensor, then standardize\n",
    "# with mean 0.1307 and std 0.3081.\n",
    "mnist_transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.1307,), (0.3081,)),\n",
    "])\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    datasets.MNIST('../data', train=True, download=True,\n",
    "                   transform=mnist_transform),\n",
    "    batch_size=batch_size, shuffle=True)\n",
    "# download=True lets this loader work even when '../data' is still empty;\n",
    "# evaluation does not depend on sample order, so the test split is not shuffled.\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    datasets.MNIST('../data', train=False, download=True,\n",
    "                   transform=mnist_transform),\n",
    "    batch_size=batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "cell_style": "split"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'tensorflow.python.data.ops.dataset_ops.BatchDataset'>\n",
      "(200, 28, 28) (200,)\n"
     ]
    }
   ],
   "source": [
    "# Peek at one batch of the TensorFlow test pipeline.\n",
    "print(type(ds_test))\n",
    "image, label = next(iter(ds_test))\n",
    "print(f\"{image.shape} {label.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "cell_style": "split",
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'torch.utils.data.dataloader.DataLoader'>\n",
      "torch.Size([200, 1, 28, 28]) torch.Size([200])\n"
     ]
    }
   ],
   "source": [
    "# Peek at one batch of the PyTorch training loader.\n",
    "print(type(train_loader))\n",
    "image, label = next(iter(train_loader))\n",
    "print(f\"{image.shape} {label.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fully connected network, implemented manually\n",
    "[b, 784] -> [b, 200] -> [b, 100] -> [b, 10]\n",
    "\n",
    "w1: [784, 200], b1: [200],\n",
    "\n",
    "w2: [200, 100], b2: [100],\n",
    "\n",
    "w3: [100, 10], b3: [10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "cell_style": "split",
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:0, step:0 loss:1404.9921875\n",
      "epoch:0, step:100 loss:2.2279036045074463\n",
      "epoch:0, step:200 loss:2.3840603828430176\n",
      "epoch:1, step:0 loss:2.1420228481292725\n",
      "epoch:1, step:100 loss:2.0314340591430664\n",
      "epoch:1, step:200 loss:2.0380284786224365\n",
      "epoch:2, step:0 loss:2.127847671508789\n",
      "epoch:2, step:100 loss:2.021336793899536\n",
      "epoch:2, step:200 loss:1.6273704767227173\n",
      "epoch:3, step:0 loss:1.5105515718460083\n",
      "epoch:3, step:100 loss:1.5376299619674683\n",
      "epoch:3, step:200 loss:1.2730509042739868\n",
      "epoch:4, step:0 loss:1.2069252729415894\n",
      "epoch:4, step:100 loss:1.358099341392517\n",
      "epoch:4, step:200 loss:1.1399234533309937\n",
      "epoch:5, step:0 loss:1.064672827720642\n",
      "epoch:5, step:100 loss:1.315996527671814\n",
      "epoch:5, step:200 loss:1.071283221244812\n",
      "epoch:6, step:0 loss:1.001338243484497\n",
      "epoch:6, step:100 loss:0.7731111645698547\n",
      "epoch:6, step:200 loss:0.4953363537788391\n",
      "epoch:7, step:0 loss:0.6500086784362793\n",
      "epoch:7, step:100 loss:0.7191830277442932\n",
      "epoch:7, step:200 loss:0.5285062193870544\n",
      "epoch:8, step:0 loss:0.4401146173477173\n",
      "epoch:8, step:100 loss:0.500275731086731\n",
      "epoch:8, step:200 loss:0.3191242516040802\n",
      "epoch:9, step:0 loss:0.46588724851608276\n",
      "epoch:9, step:100 loss:0.5581982731819153\n",
      "epoch:9, step:200 loss:0.4840467870235443\n"
     ]
    }
   ],
   "source": [
    "# weights and bias\n",
    "# He-normal initialization (std = sqrt(2 / fan_in)). Drawing the weights from\n",
    "# uniform [0, 1) makes every weight positive, so pre-activations scale with the\n",
    "# layer width and the very first loss is in the thousands.\n",
    "def he_normal_tf(fan_in, fan_out):\n",
    "    \"\"\"Return a [fan_in, fan_out] tensor sampled from N(0, 2 / fan_in).\"\"\"\n",
    "    return tf.random.normal([fan_in, fan_out], stddev=(2.0 / fan_in) ** 0.5)\n",
    "\n",
    "w1 = tf.Variable(he_normal_tf(28*28, 200))\n",
    "b1 = tf.Variable(tf.zeros([200]))\n",
    "\n",
    "w2 = tf.Variable(he_normal_tf(200, 100))\n",
    "b2 = tf.Variable(tf.zeros([100]))\n",
    "\n",
    "w3 = tf.Variable(he_normal_tf(100, 10))\n",
    "b3 = tf.Variable(tf.zeros([10]))\n",
    "\n",
    "\n",
    "# forward pass\n",
    "def model(x):\n",
    "    \"\"\"Compute logits for a batch of flattened images with the manual 3-layer MLP.\n",
    "\n",
    "    Reads the module-level variables w1/b1, w2/b2 and w3/b3. ReLU follows the\n",
    "    first two affine layers; the last layer is left linear (raw logits).\n",
    "    \"\"\"\n",
    "    hidden1 = tf.nn.relu(tf.matmul(x, w1) + b1)\n",
    "    hidden2 = tf.nn.relu(tf.matmul(hidden1, w2) + b2)\n",
    "    return tf.matmul(hidden2, w3) + b3\n",
    "\n",
    "# Adam over the six manually created variables.\n",
    "optimizer = tf.optimizers.Adam(learning_rate)\n",
    "params = [w1, b1, w2, b2, w3, b3]\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    for step, (x, y) in enumerate(ds_train):\n",
    "        # Flatten every 28x28 image into a 784-element row.\n",
    "        x = tf.reshape(x, [-1, 28*28])\n",
    "\n",
    "        # Record the forward pass so the loss can be differentiated.\n",
    "        with tf.GradientTape() as tape:\n",
    "            logits = model(x)\n",
    "            losses = tf.losses.sparse_categorical_crossentropy(\n",
    "                y, logits, from_logits=True)\n",
    "            loss = tf.reduce_mean(losses)\n",
    "\n",
    "        grads = tape.gradient(loss, params)\n",
    "        optimizer.apply_gradients(zip(grads, params))\n",
    "\n",
    "        # Log the batch loss every 100 steps.\n",
    "        if step % 100 == 0:\n",
    "            print(\"epoch:{}, step:{} loss:{}\".format(epoch, step, loss.numpy()))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "cell_style": "split"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:0, step:0, loss:9682.416015625\n",
      "epoch:0, step:100, loss:2.488150119781494\n",
      "epoch:0, step:200, loss:2.2699265480041504\n",
      "epoch:1, step:0, loss:2.274050235748291\n",
      "epoch:1, step:100, loss:2.34322452545166\n",
      "epoch:1, step:200, loss:2.2167744636535645\n",
      "epoch:2, step:0, loss:2.3210036754608154\n",
      "epoch:2, step:100, loss:2.2207939624786377\n",
      "epoch:2, step:200, loss:2.4887542724609375\n",
      "epoch:3, step:0, loss:2.344801187515259\n",
      "epoch:3, step:100, loss:2.1573853492736816\n",
      "epoch:3, step:200, loss:2.2278380393981934\n",
      "epoch:4, step:0, loss:2.2109551429748535\n",
      "epoch:4, step:100, loss:2.190887928009033\n",
      "epoch:4, step:200, loss:2.702714204788208\n",
      "epoch:5, step:0, loss:2.1471283435821533\n",
      "epoch:5, step:100, loss:2.527514696121216\n",
      "epoch:5, step:200, loss:2.131390333175659\n",
      "epoch:6, step:0, loss:2.4100725650787354\n",
      "epoch:6, step:100, loss:1.98695969581604\n",
      "epoch:6, step:200, loss:1.888046383857727\n",
      "epoch:7, step:0, loss:2.0443994998931885\n",
      "epoch:7, step:100, loss:1.8937968015670776\n",
      "epoch:7, step:200, loss:1.9226715564727783\n",
      "epoch:8, step:0, loss:1.7935198545455933\n",
      "epoch:8, step:100, loss:1.7859866619110107\n",
      "epoch:8, step:200, loss:1.7676191329956055\n",
      "epoch:9, step:0, loss:1.6434701681137085\n",
      "epoch:9, step:100, loss:1.6301932334899902\n",
      "epoch:9, step:200, loss:1.8712739944458008\n"
     ]
    }
   ],
   "source": [
    "# weights and bias\n",
    "# He-normal initialization (std = sqrt(2 / fan_in)). torch.rand samples from\n",
    "# uniform [0, 1), which makes every weight positive, so pre-activations scale\n",
    "# with the layer width and the very first loss is in the thousands.\n",
    "def he_normal_torch(fan_in, fan_out):\n",
    "    \"\"\"Return a trainable leaf tensor of shape (fan_in, fan_out) sampled from N(0, 2 / fan_in).\"\"\"\n",
    "    return (torch.randn(fan_in, fan_out) * (2.0 / fan_in) ** 0.5).requires_grad_()\n",
    "\n",
    "w1 = he_normal_torch(28*28, 200)\n",
    "b1 = torch.zeros(200, requires_grad=True)\n",
    "\n",
    "w2 = he_normal_torch(200, 100)\n",
    "b2 = torch.zeros(100, requires_grad=True)\n",
    "\n",
    "w3 = he_normal_torch(100, 10)\n",
    "b3 = torch.zeros(10, requires_grad=True)\n",
    "\n",
    "# forward pass\n",
    "def forward(x):\n",
    "    \"\"\"Compute logits for a batch of flattened images with the manual 3-layer MLP.\n",
    "\n",
    "    Reads the module-level tensors w1/b1, w2/b2 and w3/b3. ReLU follows the\n",
    "    first two affine layers; the last layer is left linear (raw logits).\n",
    "    \"\"\"\n",
    "    hidden1 = F.relu(torch.matmul(x, w1) + b1)\n",
    "    hidden2 = F.relu(torch.matmul(hidden1, w2) + b2)\n",
    "    return torch.matmul(hidden2, w3) + b3\n",
    "\n",
    "# Adam over the six manually created tensors, with cross-entropy on raw logits.\n",
    "params = [w1, b1, w2, b2, w3, b3]\n",
    "optimizer = torch.optim.Adam(params, lr=learning_rate)\n",
    "criteon = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    for step, (x, y) in enumerate(train_loader):\n",
    "        # Flatten every image into a 784-element row.\n",
    "        x = x.reshape(-1, 28*28)\n",
    "\n",
    "        # Clear the gradients left over from the previous step.\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        logits = forward(x)\n",
    "        loss = criteon(logits, y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Log the batch loss every 100 steps.\n",
    "        if step % 100 == 0:\n",
    "            print(\"epoch:{}, step:{}, loss:{}\".format(epoch, step, loss.item()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fully connected network, higher-level API"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {
    "cell_style": "split"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:0, step:0 loss:2.366990566253662\n",
      "accuracy:  0.4342\n",
      "epoch:0, step:100 loss:0.1408572494983673\n",
      "accuracy:  0.9444\n",
      "epoch:0, step:200 loss:0.1306784301996231\n",
      "accuracy:  0.9476\n",
      "epoch:1, step:0 loss:0.18211042881011963\n",
      "accuracy:  0.9523\n",
      "epoch:1, step:100 loss:0.1134713813662529\n",
      "accuracy:  0.9645\n",
      "epoch:1, step:200 loss:0.10413701087236404\n",
      "accuracy:  0.9562\n",
      "epoch:2, step:0 loss:0.10256875306367874\n",
      "accuracy:  0.9589\n",
      "epoch:2, step:100 loss:0.08379142731428146\n",
      "accuracy:  0.9636\n",
      "epoch:2, step:200 loss:0.06404948979616165\n",
      "accuracy:  0.9685\n",
      "epoch:3, step:0 loss:0.031297821551561356\n",
      "accuracy:  0.9668\n",
      "epoch:3, step:100 loss:0.06204662472009659\n",
      "accuracy:  0.9681\n",
      "epoch:3, step:200 loss:0.039909422397613525\n",
      "accuracy:  0.9734\n",
      "epoch:4, step:0 loss:0.12970837950706482\n",
      "accuracy:  0.9707\n",
      "epoch:4, step:100 loss:0.1105945035815239\n",
      "accuracy:  0.9647\n",
      "epoch:4, step:200 loss:0.1333925724029541\n",
      "accuracy:  0.969\n",
      "epoch:5, step:0 loss:0.0438968688249588\n",
      "accuracy:  0.9689\n",
      "epoch:5, step:100 loss:0.06427070498466492\n",
      "accuracy:  0.9751\n",
      "epoch:5, step:200 loss:0.04650873690843582\n",
      "accuracy:  0.9689\n",
      "epoch:6, step:0 loss:0.06343336403369904\n",
      "accuracy:  0.9697\n",
      "epoch:6, step:100 loss:0.07626428455114365\n",
      "accuracy:  0.974\n",
      "epoch:6, step:200 loss:0.055490318685770035\n",
      "accuracy:  0.9682\n",
      "epoch:7, step:0 loss:0.1507387012243271\n",
      "accuracy:  0.9664\n",
      "epoch:7, step:100 loss:0.09286125004291534\n",
      "accuracy:  0.9731\n",
      "epoch:7, step:200 loss:0.07143430411815643\n",
      "accuracy:  0.9689\n",
      "epoch:8, step:0 loss:0.08946146070957184\n",
      "accuracy:  0.9714\n",
      "epoch:8, step:100 loss:0.027022486552596092\n",
      "accuracy:  0.9722\n",
      "epoch:8, step:200 loss:0.0307907834649086\n",
      "accuracy:  0.9683\n",
      "epoch:9, step:0 loss:0.07391297072172165\n",
      "accuracy:  0.9693\n",
      "epoch:9, step:100 loss:0.08514115959405899\n",
      "accuracy:  0.9728\n",
      "epoch:9, step:200 loss:0.05861741676926613\n",
      "accuracy:  0.9687\n"
     ]
    }
   ],
   "source": [
    "class FC_model(keras.Model):\n",
    "    \"\"\"Fully connected classifier: Dense(200) -> ReLU -> Dense(100) -> ReLU -> Dense(10).\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        stack = keras.Sequential()\n",
    "        for units in (200, 100):\n",
    "            stack.add(layers.Dense(units))\n",
    "            stack.add(layers.ReLU())\n",
    "        # The final Dense layer has no activation, so the model returns raw logits.\n",
    "        stack.add(layers.Dense(10))\n",
    "        self.model = stack\n",
    "\n",
    "    def call(self, x):\n",
    "        \"\"\"Return the logits produced by the layer stack for input batch x.\"\"\"\n",
    "        return self.model(x)\n",
    "    \n",
    "# Train FC_model with a custom GradientTape loop and report test accuracy\n",
    "# every 100 steps.\n",
    "model = FC_model()\n",
    "optimizer = tf.optimizers.Adam(learning_rate)\n",
    "\n",
    "for epoch in range(epochs):\n",
    "\n",
    "    for step, (x, y) in enumerate(ds_train):\n",
    "        # Flatten every 28x28 image into a 784-element row.\n",
    "        x = tf.reshape(x, [-1, 28*28])\n",
    "        with tf.GradientTape() as tape:\n",
    "            logits = model(x)\n",
    "\n",
    "            losses = tf.losses.sparse_categorical_crossentropy(y, logits, from_logits=True)\n",
    "            loss = tf.reduce_mean(losses)\n",
    "\n",
    "        # Differentiate only the trainable weights; model.variables would also\n",
    "        # include any non-trainable state a layer might own.\n",
    "        grads = tape.gradient(loss, model.trainable_variables)\n",
    "\n",
    "        optimizer.apply_gradients(zip(grads, model.trainable_variables))\n",
    "\n",
    "        if step % 100 == 0:\n",
    "            print(\"epoch:{}, step:{} loss:{}\".\n",
    "                  format(epoch, step, loss.numpy()))\n",
    "\n",
    "            # Test accuracy over the whole test set.\n",
    "            total_correct = 0\n",
    "            total_num = 0\n",
    "\n",
    "            # Distinct loop names, so the x_test / y_test arrays created by the\n",
    "            # data-loading cell are not overwritten with the last batch.\n",
    "            for test_images, test_labels in ds_test:\n",
    "                test_images = tf.reshape(test_images, [-1, 28*28])\n",
    "                y_pred = tf.argmax(model(test_images), axis=1)\n",
    "                # argmax yields int64; cast to match the int32 labels.\n",
    "                y_pred = tf.cast(y_pred, tf.int32)\n",
    "                correct = tf.cast((y_pred == test_labels), tf.int32)\n",
    "                correct = tf.reduce_sum(correct)\n",
    "\n",
    "                total_correct += int(correct)\n",
    "                total_num += test_images.shape[0]\n",
    "\n",
    "            accuracy = total_correct/total_num\n",
    "            print('accuracy: ', accuracy)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {
    "cell_style": "split"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:0, step:0, loss:2.2986953258514404\n",
      "accuracy:  0.17790000140666962\n",
      "epoch:0, step:100, loss:0.17231173813343048\n",
      "accuracy:  0.9359999895095825\n",
      "epoch:0, step:200, loss:0.15071144700050354\n",
      "accuracy:  0.9472999572753906\n",
      "epoch:1, step:0, loss:0.07731062173843384\n",
      "accuracy:  0.961899995803833\n",
      "epoch:1, step:100, loss:0.19826166331768036\n",
      "accuracy:  0.9538999795913696\n",
      "epoch:1, step:200, loss:0.14187023043632507\n",
      "accuracy:  0.9587999582290649\n",
      "epoch:2, step:0, loss:0.0891067311167717\n",
      "accuracy:  0.9620999693870544\n",
      "epoch:2, step:100, loss:0.12733696401119232\n",
      "accuracy:  0.9612999558448792\n",
      "epoch:2, step:200, loss:0.11065959930419922\n",
      "accuracy:  0.9555999636650085\n",
      "epoch:3, step:0, loss:0.1741502583026886\n",
      "accuracy:  0.9610999822616577\n",
      "epoch:3, step:100, loss:0.06588414311408997\n",
      "accuracy:  0.9656999707221985\n",
      "epoch:3, step:200, loss:0.1473999172449112\n",
      "accuracy:  0.9538999795913696\n",
      "epoch:4, step:0, loss:0.1541697084903717\n",
      "accuracy:  0.9605000019073486\n",
      "epoch:4, step:100, loss:0.04024747014045715\n",
      "accuracy:  0.9623000025749207\n",
      "epoch:4, step:200, loss:0.13536378741264343\n",
      "accuracy:  0.9617999792098999\n",
      "epoch:5, step:0, loss:0.05285172164440155\n",
      "accuracy:  0.9599999785423279\n",
      "epoch:5, step:100, loss:0.07005582004785538\n",
      "accuracy:  0.9711999893188477\n",
      "epoch:5, step:200, loss:0.12366761267185211\n",
      "accuracy:  0.968999981880188\n",
      "epoch:6, step:0, loss:0.06925167143344879\n",
      "accuracy:  0.9702999591827393\n",
      "epoch:6, step:100, loss:0.067380391061306\n",
      "accuracy:  0.9673999547958374\n",
      "epoch:6, step:200, loss:0.08025708794593811\n",
      "accuracy:  0.9655999541282654\n",
      "epoch:7, step:0, loss:0.06105445697903633\n",
      "accuracy:  0.9613999724388123\n",
      "epoch:7, step:100, loss:0.0417051687836647\n",
      "accuracy:  0.9678999781608582\n",
      "epoch:7, step:200, loss:0.061699021607637405\n",
      "accuracy:  0.9688999652862549\n",
      "epoch:8, step:0, loss:0.07633427530527115\n",
      "accuracy:  0.9649999737739563\n",
      "epoch:8, step:100, loss:0.19862857460975647\n",
      "accuracy:  0.9708999991416931\n",
      "epoch:8, step:200, loss:0.07430338114500046\n",
      "accuracy:  0.9644999504089355\n",
      "epoch:9, step:0, loss:0.06198420375585556\n",
      "accuracy:  0.9674999713897705\n",
      "epoch:9, step:100, loss:0.03901655972003937\n",
      "accuracy:  0.9729999899864197\n",
      "epoch:9, step:200, loss:0.028583237901329994\n",
      "accuracy:  0.9659000039100647\n"
     ]
    }
   ],
   "source": [
    "class FC_NN(nn.Module):\n",
    "    \"\"\"Fully connected classifier: 784 -> 200 -> 100 -> 10 with ReLU after the first two layers.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        hidden = [\n",
    "            nn.Linear(28*28, 200), nn.ReLU(inplace=True),\n",
    "            nn.Linear(200, 100), nn.ReLU(inplace=True),\n",
    "        ]\n",
    "        # The output layer has no activation, so the network returns raw logits.\n",
    "        self.model = nn.Sequential(*hidden, nn.Linear(100, 10))\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the logits produced by the layer stack for input batch x.\"\"\"\n",
    "        return self.model(x)\n",
    "\n",
    "# Use the first GPU when CUDA is available; otherwise fall back to the CPU\n",
    "# instead of failing when the model is moved to a hard-coded 'cuda:0'.\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "network = FC_NN().to(device)\n",
    "optimizer = torch.optim.Adam(network.parameters(),\n",
    "                            lr=learning_rate)\n",
    "criteon = torch.nn.CrossEntropyLoss().to(device)\n",
    "\n",
    "# Train the network and report test accuracy every 100 steps.\n",
    "for epoch in range(epochs):\n",
    "\n",
    "    for step, (x, y) in enumerate(train_loader):\n",
    "        # Flatten every image into a 784-element row.\n",
    "        x = x.reshape(-1, 28*28)\n",
    "\n",
    "        x, y = x.to(device), y.to(device)\n",
    "\n",
    "        logits = network(x)\n",
    "        loss = criteon(logits, y)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if step % 100 == 0:\n",
    "            print(\"epoch:{}, step:{}, loss:{}\".\n",
    "                  format(epoch, step, loss.item()))\n",
    "\n",
    "            # Test accuracy over the whole test set.\n",
    "            total_correct = 0\n",
    "            total_num = 0\n",
    "\n",
    "            # Evaluation needs no gradients; no_grad() skips building the\n",
    "            # autograd graph for these forward passes.\n",
    "            with torch.no_grad():\n",
    "                for x_test, y_test in test_loader:\n",
    "                    x_test = x_test.reshape(-1, 28*28)\n",
    "                    x_test, y_test = x_test.to(device), y_test.to(device)\n",
    "\n",
    "                    y_pred = torch.argmax(network(x_test), dim=1)\n",
    "\n",
    "                    # .item() keeps the running count a plain Python int.\n",
    "                    total_correct += (y_pred == y_test).sum().item()\n",
    "                    total_num += x_test.shape[0]\n",
    "\n",
    "            acc = total_correct / total_num\n",
    "            print(\"accuracy: \", acc)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
